/*
 * @Copyright: Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
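 * @Description: Public API declarations of the ktfop operator library (logging hook, element-wise ops, table lookup, matmul, select).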
 * @Version: 1.0
 * @Date: 2024-9-18 14:50:00
 * @LastEditors: dev
 * @LastEditTime: 2024-9-18 14:50:00
 */

#ifndef KTFOP_H
#define KTFOP_H

#include <cstdlib>
#include <cstdint>
#include <kblas.h>

/*
 * KTFOP_API_PUBLIC marks a declaration as exported from the ktfop shared library.
 * `__attribute__((visibility("default")))` is a GCC/Clang extension, so other compilers
 * get an empty expansion instead of a syntax error. A build may predefine
 * KTFOP_API_PUBLIC to supply its own export annotation.
 */
#ifndef KTFOP_API_PUBLIC
#if defined(__GNUC__) || defined(__clang__)
#define KTFOP_API_PUBLIC __attribute__((visibility("default")))
#else
#define KTFOP_API_PUBLIC
#endif
#endif

#ifdef __cplusplus
extern "C" {
#endif

/*
 * Signature of a caller-supplied log sink, as accepted by ktfop::SetExternalLogFunc.
 * It receives an integer log level and a message string. The meaning of individual
 * level values and the lifetime of `msg` are defined by the implementation, not by
 * this header (TODO document).
 * The typedef form is valid in both C and C++, matching the __cplusplus guards around it.
 */
typedef void (*ExternalLog)(int level, const char *msg);

#ifdef __cplusplus
}
#endif

namespace ktfop {
/**
 * Sets the logging callback used by the library (presumably replacing any previously
 * set one; the implementation is not in this header, so confirm).
 *
 * @param logFunc  Callback of type ExternalLog (level, message). Whether nullptr is
 *                 accepted, e.g. to disable external logging, is not visible here (TODO confirm).
 * @return int status code. The meaning of specific values is defined by the
 *         implementation, not by this header.
 */
KTFOP_API_PUBLIC int SetExternalLogFunc(ExternalLog logFunc);
} // namespace ktfop

// Less Greater
namespace ktfop {
/*
 * Element-wise comparison kernels.
 *
 * Each overload takes `length` and a `bool *output`. Presumably
 * output[i] = (lhs[i] < rhs[i]) for Less and (lhs[i] > rhs[i]) for Greater, for i in
 * [0, length), with a scalar operand applied to every element. The implementation is
 * not in this header, so confirm.
 *
 * Notes for callers, which follow from the declarations alone:
 *  - The templates are only declared here. A call links only if the library explicitly
 *    instantiates that T. Which element types are provided is not visible in this
 *    header (TODO document).
 *  - The scalar operand also deduces T. It must therefore have exactly the element type
 *    (e.g. `1.0f` with a `float *` array), or T must be given explicitly:
 *    `Less<float>(p, 1, out, n)`.
 *  - Inputs are taken as non-const `T *`. Passing a `const X *` deduces T = const X,
 *    which is unlikely to be instantiated.
 *  - The int return value is a status code whose values are not defined in this header.
 */

// Less: both operands are arrays of `length` elements.
template <typename T> KTFOP_API_PUBLIC int Less(T *input0, T *input1, bool *output, size_t length);

// Less, scalar on the right: input1 is a single value.
template <typename T> KTFOP_API_PUBLIC int Less(T *input0, T input1, bool *output, size_t length);

// Less, scalar on the left: input0 is a single value.
template <typename T> KTFOP_API_PUBLIC int Less(T input0, T *input1, bool *output, size_t length);

// Greater: both operands are arrays of `length` elements.
template <typename T> KTFOP_API_PUBLIC int Greater(T *input0, T *input1, bool *output, size_t length);

// Greater, scalar on the right: input1 is a single value.
template <typename T> KTFOP_API_PUBLIC int Greater(T *input0, T input1, bool *output, size_t length);

// Greater, scalar on the left: input0 is a single value.
template <typename T> KTFOP_API_PUBLIC int Greater(T input0, T *input1, bool *output, size_t length);
} // namespace ktfop

// FloorMod
namespace ktfop {
/*
 * Element-wise floor modulo.
 *
 * Presumably output[i] = input[i] mod mod[i] with the result taking the sign of the
 * divisor (floor semantics), for i in [0, length); a scalar operand applies to every
 * element. The implementation is not in this header, so confirm, including the
 * behavior when a divisor is 0.
 *
 * As with Less/Greater:
 *  - every operand, including the scalar, deduces T, so the scalar must have exactly
 *    the element type or T must be given explicitly;
 *  - the templates link only for element types the library explicitly instantiates;
 *  - the int return value is a status code not defined in this header.
 */

// Array dividend, array divisor.
template <typename T> KTFOP_API_PUBLIC int FloorMod(T *input, T *mod, T *output, size_t length);

// Scalar dividend, array divisor.
template <typename T> KTFOP_API_PUBLIC int FloorMod(T input, T *mod, T *output, size_t length);

// Array dividend, scalar divisor.
template <typename T> KTFOP_API_PUBLIC int FloorMod(T *input, T mod, T *output, size_t length);
} // namespace ktfop

// LookupTableFind
namespace ktfop {
/*
 * Description of an existing lookup table, passed to Find.
 *
 * The struct only borrows the bucket arrays through raw pointers. It has no destructor,
 * so it never frees them, and the caller keeps them alive for the duration of a call.
 * The field order is part of the binary interface with the compiled library: do not
 * reorder fields.
 */
struct TableInfo {
    int64_t numBuckets;   // number of bucket slots; presumably the element count of keyBucket (confirm)
    int64_t valueSize;    // presumably the number of floats per stored value (confirm)
    int64_t emptyKey;     // presumably the sentinel key marking an unused slot (confirm)
    int64_t deletedKey;   // presumably the sentinel key marking a deleted slot (confirm)
    int64_t *keyBucket;   // borrowed pointer to the key slots
    float *valueBucket;   // borrowed pointer to the value slots; presumably numBuckets * valueSize floats (confirm)
};

/**
 * Presumably looks up each of `length` keys in the table described by `info` and
 * writes the matching value, or `defaultValue` when the key is absent, into `values`.
 * The implementation is not in this header, so confirm.
 *
 * @param info          Table description. `const` applies only to the struct, not to the
 *                      bucket arrays it points to.
 * @param keys          `length` keys to look up.
 * @param values        Output buffer; presumably length * info.valueSize floats (confirm).
 * @param defaultValue  Fallback value; presumably info.valueSize floats (confirm).
 * @param length        Number of keys. Note that this is int64_t, whereas the element-wise
 *                      ops in this header take size_t.
 * @return int status code; values are not defined in this header.
 */
KTFOP_API_PUBLIC int Find(const TableInfo &info, const int64_t *keys, float *values,
                          const float *defaultValue, int64_t length);
} // namespace ktfop

// Matmul
namespace ktfop {
/**
 * Scalar arguments for Matmul.
 *
 * The fields mirror the non-pointer arguments of the CBLAS `?gemm` routines: order,
 * transposes, m/n/k, alpha, leading dimensions and beta. Presumably Matmul forwards them
 * to a kblas gemm call (TODO confirm).
 *
 * A default-constructed instance uses:
 *  - column-major order;
 *  - no transposition of A or B;
 *  - alpha = 1 and beta = 0;
 *  - m = n = k = lda = ldb = ldc = 0.
 *
 * NOTE(review): BLAS requires leading dimensions >= 1, so callers must fill in the
 * sizes and leading dimensions before use.
 */
template <typename T> struct MatMulParams {
    enum CBLAS_ORDER order = CblasColMajor;     // storage order of the matrices
    enum CBLAS_TRANSPOSE transA = CblasNoTrans; // operation applied to A
    enum CBLAS_TRANSPOSE transB = CblasNoTrans; // operation applied to B
    BLASINT m = 0;
    BLASINT n = 0;
    BLASINT k = 0;
    T alpha = T(1); // scale factor of the product term
    BLASINT lda = 0;
    BLASINT ldb = 0;
    T beta = T(0);  // scale factor of the existing output term
    BLASINT ldc = 0;

    // Default construction: every field takes the in-class initializer above.
    MatMulParams() = default;

    // Field-by-field construction; each argument is stored unchanged in the same-named member.
    MatMulParams(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE transA, enum CBLAS_TRANSPOSE transB, BLASINT m, BLASINT n,
        BLASINT k, T alpha, BLASINT lda, BLASINT ldb, T beta, BLASINT ldc)
        : order{order}, transA{transA}, transB{transB},
          m{m}, n{n}, k{k},
          alpha{alpha}, lda{lda}, ldb{ldb},
          beta{beta}, ldc{ldc}
    {}
};

/**
 * Matrix multiplication driven by MatMulParams.
 *
 * blockA, blockB and output are presumably the A, B and C operands of gemm, laid out as
 * described by M (TODO confirm). The template is only declared here, so it links only for
 * element types the library explicitly instantiates.
 *
 * @param M  Taken by non-const lvalue reference, so it cannot bind a temporary such as
 *           `MatMulParams<float>{}`; pass a named object. Whether Matmul modifies it is
 *           not visible in this header.
 * @return int status code; values are not defined in this header.
 */
template <typename T> KTFOP_API_PUBLIC int Matmul(T *blockA, T *blockB, T *output, MatMulParams<T> &M);
} // namespace ktfop

// Select
namespace ktfop {
/**
 * Element-wise select between two arrays.
 *
 * Presumably output[i] = cond[i] ? thenBranch[i] : elseBranch[i] for i in [0, length).
 * The implementation is not in this header, so confirm.
 *
 * The template is only declared here, so it links only for element types T the library
 * explicitly instantiates.
 *
 * @return int status code; values are not defined in this header.
 */
template <typename T> KTFOP_API_PUBLIC int Select(bool *cond, T *thenBranch, T *elseBranch, T *output, size_t length);
} // namespace ktfop

#endif // KTFOP_H